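"""Hand Retargeting

A simple Shadow Hand retargeting example: retargets DexYCB (MANO) hand motion
onto the Shadow Hand and visualizes the result with viser.
Before running, unzip the Shadow Hand URDF at
`retarget_helpers/hand/shadowhand_urdf.zip` (next to this script).
"""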

import pickle
import time
from pathlib import Path
from typing import Tuple, TypedDict

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
import jaxlie
import jaxls
import numpy as onp
import pyroki as pk
import trimesh
import viser
import yourdfpy
from scipy.spatial.transform import Rotation as R
from viser.extras import ViserUrdf

from retarget_helpers._utils import (
    MANO_TO_SHADOW_MAPPING,
    create_conn_tree,
    get_mapping_from_mano_to_shadow,
)


class RetargetingWeights(TypedDict):
    """Scalar weights for the cost terms of the retargeting problem."""

    # Local alignment: match the relative keypoint/link offsets and the angles
    # between those offsets.
    local_alignment: float
    # Global alignment: match absolute keypoint positions to the robot links.
    global_alignment: float
    # Smoothness of the joint configuration across timesteps.
    joint_smoothness: float
    # Smoothness of the root translation across timesteps.
    root_smoothness: float


def main():
    """Run the interactive Shadow Hand retargeting demo.

    Loads the Shadow Hand URDF and a DexYCB motion clip, solves the
    retargeting problem once, then loops forever playing the result back in a
    viser scene. The scene also shows the source hand mesh, the MANO
    keypoints, the per-frame contact points and the manipulated object. The
    "Retarget!" GUI button re-solves using the weights currently set in the
    GUI weight tuner.

    Raises:
        FileNotFoundError: If the Shadow Hand URDF has not been unzipped.
        ValueError: If the loaded hand keypoints contain NaN values.
    """

    asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"

    robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"

    def filename_handler(fname: str) -> str:
        """Resolve asset paths referenced by the URDF relative to its directory."""
        return yourdfpy.filename_handler_magic(fname, dir=robot_urdf_path.parent)

    try:
        urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
    except FileNotFoundError as err:
        # Chain the original error so the actually-missing path is still reported.
        raise FileNotFoundError(
            "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
        ) from err

    robot = pk.Robot.from_urdf(urdf)

    # Get the mapping from MANO to Shadow Hand joints: paired Shadow Hand link
    # indices and MANO keypoint indices used as retargeting correspondences.
    shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)

    # Create a mask for the MANO joints that are connected to the Shadow Hand.
    mano_mask = create_conn_tree(robot, shadow_link_idx)

    # Load source motion data. `pickle` can execute arbitrary code, so only
    # the bundled, trusted asset should ever be loaded here.
    dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
    with open(dexycb_motion_path, "rb") as f:
        dexycb_motion_data = pickle.load(f, encoding="latin1")

    # Load keypoints (axis 0 is time). Validate with an explicit raise rather
    # than `assert`, which is stripped under `python -O`.
    keypoints = dexycb_motion_data["world_hand_joints"]
    if onp.isnan(keypoints).any():
        raise ValueError(
            f"Hand keypoints loaded from {dexycb_motion_path} contain NaN values."
        )
    num_timesteps = keypoints.shape[0]

    # Per-frame contact points on the object: one (possibly empty) list of
    # points per timestep. Only visualized in this simplified example.
    contact_points_per_frame = dexycb_motion_data["contact_object_points"]

    # Load the object.
    object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
    object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
    object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
    mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)

    server = viser.ViserServer()

    # We will transform everything by the transform below, for aesthetics.
    server.scene.add_frame(
        "/scene_offset",
        show_axes=False,
        position=(-0.15415953, -0.73598871, 0.93434792),
        wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
    )
    hand_mesh = server.scene.add_mesh_simple(
        "/scene_offset/hand_mesh",
        vertices=dexycb_motion_data["world_hand_vertices"][0, :, :],
        faces=dexycb_motion_data["hand_mesh_faces"],
        opacity=0.5,
    )
    # The robot is drawn under `base_frame`; its pose is set every frame to
    # the solved root transform.
    base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
    playing = server.gui.add_checkbox("playing", True)
    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
    object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
    server.scene.add_grid("/grid", 2.0, 2.0)

    default_weights = RetargetingWeights(
        local_alignment=10.0,
        global_alignment=1.0,
        joint_smoothness=2.0,
        root_smoothness=2.0,
    )

    weights = pk.viewer.WeightTuner(
        server,
        default_weights,  # type: ignore
    )

    # Latest solution; rebound by `generate_trajectory`, which also runs from
    # the GUI button callback.
    Ts_world_root, joints = None, None

    def generate_trajectory():
        """Re-solve retargeting using the weights currently set in the GUI."""
        nonlocal Ts_world_root, joints
        gen_button.disabled = True
        try:
            Ts_world_root, joints = solve_retargeting(
                robot=robot,
                target_keypoints=keypoints,
                shadow_hand_link_retarget_indices=shadow_link_idx,
                mano_joint_retarget_indices=mano_joint_idx,
                mano_mask=mano_mask,
                weights=weights.get_weights(),  # type: ignore
            )
        finally:
            # Re-enable the button even if the solve fails, so the user can retry.
            gen_button.disabled = False

    gen_button = server.gui.add_button("Retarget!")
    gen_button.on_click(lambda _: generate_trajectory())

    generate_trajectory()
    assert Ts_world_root is not None and joints is not None

    # Constant RGB colors for the keypoint (blue) and contact (red) markers.
    keypoint_color = onp.array((0, 0, 255))
    contact_color = onp.array((255, 0, 0))

    while True:
        with server.atomic():
            if playing.value:
                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
            tstep = timestep_slider.value

            # Root pose stored as [qw, qx, qy, qz, x, y, z]; read it once so the
            # orientation and position come from the same array.
            wxyz_xyz = onp.array(Ts_world_root.wxyz_xyz[tstep])
            base_frame.wxyz = wxyz_xyz[:4]
            base_frame.position = wxyz_xyz[4:]
            urdf_vis.update_cfg(onp.array(joints[tstep]))

            # One color per drawn point, sized from the points themselves so the
            # two arrays can never disagree in length.
            target_points = onp.array(keypoints[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/target_keypoints",
                target_points,
                keypoint_color[None].repeat(len(target_points), axis=0),
                point_size=0.005,
                point_shape="sparkle",
            )
            contact_points = onp.array(contact_points_per_frame[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/contact_points",
                contact_points,
                contact_color[None].repeat(len(contact_points), axis=0),
                point_size=0.005,
                point_shape="circle",
            )
            hand_mesh.vertices = dexycb_motion_data["world_hand_vertices"][tstep, :, :]
            object_handle.position = object_pose_list[tstep][:3, 3]
            object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
                scalar_first=True
            )

        time.sleep(0.05)


@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    shadow_hand_link_retarget_indices: jnp.ndarray,
    mano_joint_retarget_indices: jnp.ndarray,
    mano_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem.

    Builds one nonlinear least-squares problem (jaxls) over the whole
    trajectory. It jointly optimizes a Shadow Hand joint configuration and an
    SE(3) root pose for every timestep so that the retargeted robot links
    follow the MANO keypoints.

    Args:
        robot: Shadow Hand robot model.
        target_keypoints: MANO keypoints; axis 0 is time. Presumably
            (timesteps, num_mano_joints, 3) world positions -- TODO confirm.
        shadow_hand_link_retarget_indices: Robot link indices to match, paired
            element-wise with `mano_joint_retarget_indices`.
        mano_joint_retarget_indices: MANO keypoint indices to match (N of them).
        mano_mask: Mask multiplied into the pairwise local-alignment residuals.
        weights: Cost weights; see `RetargetingWeights`.

    Returns:
        A batched `jaxlie.SE3` of root poses (one per timestep) and the solved
        joint configurations (one row per timestep).
    """

    # Number of MANO <-> Shadow Hand correspondences (N).
    n_retarget = len(mano_joint_retarget_indices)
    # Axis 0 of the keypoints is time; variables are created per timestep.
    timesteps = target_keypoints.shape[0]

    # Variables.
    # (N, N) scale applied to the robot's pairwise link offsets before they are
    # compared with the MANO offsets; initialized to all ones.
    class ManoJointsScaleVar(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    # NOTE(review): no cost below references `OffsetVar`, so it presumably stays
    # at its zero default and the final `from_translation(offset)` is an
    # identity -- confirm, and consider removing it.
    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # Every timestep uses variable id 0, so presumably a single scale matrix
    # (and, below, a single offset) is shared by the whole trajectory.
    var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs.
    # NOTE: this empty list is replaced wholesale by the list built below.
    costs: list[jaxls.Cost] = []

    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: ManoJointsScaleVar,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Local-alignment residual for one timestep's variables.

        Compares the N x N grid of pairwise offset vectors between the
        retargeted MANO keypoints with the same grid for the robot links:
        - a position residual (MANO offset minus scaled robot offset), and
        - a direction residual (1 - cosine between the two offsets).
        Self-pairs (the diagonal) are zeroed, `mano_mask` is applied to both,
        and the result is scaled by `weights["local_alignment"]`.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[
            jnp.array(shadow_hand_link_retarget_indices)
        ]

        # NxN grid of relative positions.
        delta_mano = mano_pos[:, None] - mano_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization.
        # (1 - eye) zeroes the self-pair diagonal.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_mano - delta_robot * position_scale)
            * (1 - jnp.eye(delta_mano.shape[0])[..., None])
            * mano_mask[..., None]
        )

        # Vector angle regularization.
        # The 1e-6 is added to the vector (not to the norm) so the norm stays
        # nonzero on the diagonal, where the offset is exactly zero (0/0 guard).
        delta_mano_normalized = delta_mano / jnp.linalg.norm(
            delta_mano + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * mano_mask
        )

        residual = (
            jnp.concatenate(
                [
                    residual_position_delta.flatten(),
                    residual_angle_delta.flatten(),
                ],
                axis=0,
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame.

        Residual is (robot link position - paired MANO keypoint) for every
        correspondence, scaled by `weights["global_alignment"]`.
        """
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
        keypoint_pos = keypoints[mano_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.create_factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root translation.

        Only the translation difference between the two poses is penalized;
        root rotation is left unregularized.
        """
        return (
            var_values[var_Ts_world_root].translation()
            - var_values[var_Ts_world_root_prev].translation()
        ).flatten() * weights["root_smoothness"]

    costs = [
        # Local alignment (relative offsets and angles), one per timestep.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        # Joint limits; `x[None]` gives every robot leaf a leading axis of size
        # 1, presumably so it broadcasts against the batched joint variables.
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        # Joint smoothness between consecutive timesteps (t vs. t - 1).
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([weights["joint_smoothness"]]),
        ),
        # Global alignment of robot links to the keypoints in the world frame.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        # Root translation smoothness between consecutive timesteps.
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    transform = solution[var_Ts_world_root]
    offset = solution[var_offset]
    # Apply the translation-only offset on the world side of each root pose.
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]


if __name__ == "__main__":
    # Script entry point: build the viser scene and start the playback loop.
    main()
